{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Natural gradients\n",
    "\n",
    "This notebook shows basic usage of the natural gradient optimizer, both on its own and in combination with the Adam optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-13T13:31:56.149218Z",
     "start_time": "2018-06-13T13:31:55.213980Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "import numpy as np\n",
    "import gpflow\n",
    "from gpflow.test_util import notebook_niter, notebook_range\n",
    "from gpflow.models import VGP, GPR, SGPR, SVGP\n",
    "from gpflow.training import NatGradOptimizer, AdamOptimizer, XiSqrtMeanVar\n",
    "\n",
    "%matplotlib inline\n",
    "%precision 4\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "N, D = 100, 2\n",
    "\n",
    "# number of inducing points\n",
    "M = 10\n",
    "\n",
    "X = np.random.uniform(size=(N, D))\n",
    "Y = np.sin(10 * X)\n",
    "Z = np.random.uniform(size=(M, D))\n",
    "adam_learning_rate = 0.01\n",
    "iterations = notebook_niter(5)\n",
    "\n",
    "\n",
    "def make_matern_kernel():\n",
    "    return gpflow.kernels.Matern52(D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"VGP is a GPR\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we will demonstrate how natural gradients can turn VGP into GPR in a *single step, if the likelihood is Gaussian*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by first creating a standard GPR model with Gaussian likelihood:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T09:59:57.799754Z",
     "start_time": "2018-04-30T09:59:57.547423Z"
    }
   },
   "outputs": [],
   "source": [
    "gpr = GPR(X, Y, kern=make_matern_kernel())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The log marginal likelihood of the exact GP model is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-231.0899"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gpr.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will create an approximate model, which approximates the true posterior with a variational Gaussian distribution.<br>We initialize the distribution to have zero mean and unit variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T09:59:57.799754Z",
     "start_time": "2018-04-30T09:59:57.547423Z"
    }
   },
   "outputs": [],
   "source": [
    "vgp = VGP(X, Y, kern=make_matern_kernel(), likelihood=gpflow.likelihoods.Gaussian())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lower bound on the log marginal likelihood (the ELBO) given by the approximate GP model is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-328.8438"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgp.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our initial guess for the variational distribution is not the true posterior, so the bound lies well below the log marginal likelihood of the exact GPR model. We can optimize the variational parameters to get a tighter bound."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In fact, we only need to take **1 step** in the natural gradient direction to recover the exact posterior:"
   ]
  },
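  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why one step suffices in the Gaussian case, here is a hedged one-dimensional sketch. The notation is an assumption made for illustration and does not appear in the GPflow code: prior $f \\sim \\mathcal{N}(0, k)$, likelihood $y \\mid f \\sim \\mathcal{N}(f, \\sigma^2)$, and $q(f) = \\mathcal{N}(m, v)$ with natural parameters $\\theta$ and expectation parameters $\\eta$:\n",
    "\n",
    "$$\\theta = \\left(\\frac{m}{v},\\; -\\frac{1}{2v}\\right), \\qquad \\eta = \\left(m,\\; m^2 + v\\right).$$\n",
    "\n",
    "Up to an additive constant, the lower bound written in terms of $\\eta$ is\n",
    "\n",
    "$$\\mathcal{L}(\\eta) = \\frac{y\\,\\eta_1}{\\sigma^2} - \\frac{\\eta_2}{2\\sigma^2} - \\frac{\\eta_2}{2k} + \\frac{1}{2}\\log\\left(\\eta_2 - \\eta_1^2\\right),$$\n",
    "\n",
    "so its gradient is\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial \\eta} = \\hat{\\theta} - \\theta, \\qquad \\hat{\\theta} = \\left(\\frac{y}{\\sigma^2},\\; -\\frac{1}{2}\\left(\\frac{1}{k} + \\frac{1}{\\sigma^2}\\right)\\right).$$\n",
    "\n",
    "For an exponential-family $q$, the natural gradient with respect to $\\theta$ equals the ordinary gradient with respect to $\\eta$, so a step of size $\\gamma$ gives\n",
    "\n",
    "$$\\theta \\leftarrow \\theta + \\gamma\\,\\frac{\\partial \\mathcal{L}}{\\partial \\eta} = (1 - \\gamma)\\,\\theta + \\gamma\\,\\hat{\\theta}.$$\n",
    "\n",
    "With $\\gamma = 1$ the update lands on $\\hat{\\theta}$, that is $\\hat{v} = (1/k + 1/\\sigma^2)^{-1}$ and $\\hat{m} = \\hat{v}\\, y / \\sigma^2$, which is the exact posterior of this toy model regardless of the starting point."
   ]
  },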
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T09:59:58.254414Z",
     "start_time": "2018-04-30T09:59:57.864617Z"
    }
   },
   "outputs": [],
   "source": [
    "natgrad_optimizer = NatGradOptimizer(gamma=1.)\n",
    "natgrad_tensor = natgrad_optimizer.make_optimize_tensor(vgp, var_list=[(vgp.q_mu, vgp.q_sqrt)])\n",
    "\n",
    "session = gpflow.get_default_session()\n",
    "session.run(natgrad_tensor)\n",
    "\n",
    "# update the cache of the variational parameters in the current session\n",
    "vgp.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ELBO of the approximate GP model after a single natural gradient step, which matches the GPR log marginal likelihood up to numerical precision:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-231.0906"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgp.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize both variational parameters and kernel hyperparameters together\n",
    "\n",
    "In the Gaussian likelihood case we can alternate between an Adam update for the hyperparameters and a NatGrad update for the variational parameters. That way, the hyperparameters are optimized as if the model were a GPR."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trick is to stop Adam from updating the variational parameters by marking them as not trainable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop Adam from optimizing the variational parameters\n",
    "vgp.q_mu.trainable = False\n",
    "vgp.q_sqrt.trainable = False\n",
    "\n",
    "# Create Adam tensors for each model\n",
    "adam_for_vgp_tensor = AdamOptimizer(learning_rate=adam_learning_rate).make_optimize_tensor(vgp)\n",
    "adam_for_gpr_tensor = AdamOptimizer(learning_rate=adam_learning_rate).make_optimize_tensor(gpr)\n",
    "\n",
    "variational_params = [(vgp.q_mu, vgp.q_sqrt)]\n",
    "natgrad_tensor = NatGradOptimizer(gamma=1.).make_optimize_tensor(vgp, var_list=variational_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPR with Adam: iteration 1 likelihood -230.6706\n",
      "GPR with Adam: iteration 2 likelihood -230.2508\n",
      "GPR with Adam: iteration 3 likelihood -229.8303\n",
      "GPR with Adam: iteration 4 likelihood -229.4093\n",
      "GPR with Adam: iteration 5 likelihood -228.9876\n"
     ]
    }
   ],
   "source": [
    "for i in range(iterations):\n",
    "    session.run(adam_for_gpr_tensor)\n",
    "    iteration = i + 1\n",
    "    likelihood = session.run(gpr.likelihood_tensor)\n",
    "    print(f'GPR with Adam: iteration {iteration} likelihood {likelihood:.04f}')\n",
    "\n",
    "# Update the cache of the parameters in the current session\n",
    "gpr.anchor(session)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGP with natural gradients and Adam: iteration 1 likelihood -230.6713\n",
      "VGP with natural gradients and Adam: iteration 2 likelihood -230.2514\n",
      "VGP with natural gradients and Adam: iteration 3 likelihood -229.8310\n",
      "VGP with natural gradients and Adam: iteration 4 likelihood -229.4099\n",
      "VGP with natural gradients and Adam: iteration 5 likelihood -228.9882\n"
     ]
    }
   ],
   "source": [
    "for i in range(iterations):\n",
    "    session.run(adam_for_vgp_tensor)\n",
    "    session.run(natgrad_tensor)\n",
    "    iteration = i + 1\n",
    "    likelihood = session.run(vgp.likelihood_tensor)\n",
    "    print(f'VGP with natural gradients and Adam: iteration {iteration} likelihood {likelihood:.04f}')\n",
    "\n",
    "# We need to alter their trainable status in order to correctly anchor them in the current session\n",
    "vgp.q_mu.trainable = True\n",
    "vgp.q_sqrt.trainable = True\n",
    "\n",
    "# Update the cache of the parameters (including the variational) in the current session\n",
    "vgp.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare GPR and VGP lengthscales after optimisation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T09:59:59.505965Z",
     "start_time": "2018-04-30T09:59:59.503129Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPR lengthscales = 0.9686\n",
      "VGP lengthscales = 0.9686\n"
     ]
    }
   ],
   "source": [
    "print(f'GPR lengthscales = {gpr.kern.lengthscales.value:.04f}')\n",
    "print(f'VGP lengthscales = {vgp.kern.lengthscales.value:.04f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Natural gradients also work for the sparse model\n",
    "Similarly, natural gradients turn SVGP into SGPR in the Gaussian likelihood case.<br>\n",
    "We can again combine natural gradients with Adam to update both the variational parameters and the hyperparameters.<br>\n",
    "Here we'll just demonstrate a single natural gradient step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T09:59:59.946016Z",
     "start_time": "2018-04-30T09:59:59.507143Z"
    }
   },
   "outputs": [],
   "source": [
    "svgp = SVGP(X, Y, kern=make_matern_kernel(), likelihood=gpflow.likelihoods.Gaussian(), Z=Z)\n",
    "sgpr = SGPR(X, Y, kern=make_matern_kernel(), Z=Z)\n",
    "\n",
    "for model in svgp, sgpr:\n",
    "    model.likelihood.variance = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Analytically optimal sparse model bound, computed in closed form by SGPR:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-281.6273"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sgpr.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SVGP likelihood before natural gradient step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1404.0805"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svgp.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T10:00:00.691979Z",
     "start_time": "2018-04-30T10:00:00.053823Z"
    }
   },
   "outputs": [],
   "source": [
    "natgrad_tensor = NatGradOptimizer(gamma=1.).make_optimize_tensor(svgp, var_list=[(svgp.q_mu, svgp.q_sqrt)])\n",
    "session = gpflow.get_default_session()\n",
    "session.run(natgrad_tensor)\n",
    "\n",
    "\n",
    "# Update the cache of the variational parameters in the current session\n",
    "svgp.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SVGP likelihood after a single natural gradient step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-281.6273"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svgp.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Minibatches\n",
    "A crucial property of the natural gradient method is that it still works with minibatches.\n",
    "In practice, though, we need a smaller step size gamma, because each step is computed from a noisy minibatch estimate of the gradient."
   ]
  },
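  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hedged sketch of why a smaller gamma helps, assuming a Gaussian likelihood and fixed hyperparameters. The symbols are introduced here for illustration and are not GPflow names: $\\theta_t$ denotes the natural parameters of the variational distribution after $t$ steps, and $\\hat{\\theta}_t$ the natural parameters that would be optimal if the minibatch drawn at step $t$, rescaled to the full data size, were the whole data set. In this conjugate setting the stochastic natural gradient is $\\hat{\\theta}_t - \\theta_t$, so each step is a convex combination:\n",
    "\n",
    "$$\\theta_{t+1} = \\theta_t + \\gamma\\,(\\hat{\\theta}_t - \\theta_t) = (1 - \\gamma)\\,\\theta_t + \\gamma\\,\\hat{\\theta}_t.$$\n",
    "\n",
    "Unrolling the recursion shows that the iterate is an exponentially weighted average of the noisy minibatch optima:\n",
    "\n",
    "$$\\theta_{t+1} = (1 - \\gamma)^{t+1}\\,\\theta_0 + \\gamma \\sum_{s=0}^{t} (1 - \\gamma)^{t-s}\\,\\hat{\\theta}_s.$$\n",
    "\n",
    "With $\\gamma = 1$ the iterate simply jumps to the latest noisy estimate $\\hat{\\theta}_t$, whereas $\\gamma < 1$ averages over many minibatches and so reduces the variance."
   ]
  },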
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-30T10:00:01.667254Z",
     "start_time": "2018-04-30T10:00:00.692973Z"
    }
   },
   "outputs": [],
   "source": [
    "svgp = SVGP(X, Y, kern=make_matern_kernel(),\n",
    "            likelihood=gpflow.likelihoods.Gaussian(), Z=Z, minibatch_size=50)\n",
    "svgp.likelihood.variance = 0.1\n",
    "\n",
    "variational_params = [(svgp.q_mu, svgp.q_sqrt)]\n",
    "natgrad = NatGradOptimizer(gamma=.1)\n",
    "natgrad_tensor = natgrad.make_optimize_tensor(svgp, var_list=variational_params)\n",
    "\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(natgrad_tensor)\n",
    "\n",
    "svgp.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Minibatch SVGP likelihood after NatGrad optimization, averaged over repeated minibatch evaluations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-282.2219"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.average([svgp.compute_log_likelihood() for _ in notebook_range(1000)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison with ordinary gradients in the conjugate case\n",
    "\n",
    "##### (Take home message: natural gradients are always better)\n",
    "\n",
    "Compared to minibatch SVGP trained with ordinary gradients, the natural gradient optimizer is much faster in the Gaussian case.\n",
    "\n",
    "Here we'll do hyperparameter learning together with optimization of the variational parameters, comparing the interleaved natural gradient approach and the one using ordinary gradients for the hyperparameters and variational parameters jointly.\n",
    "\n",
    "Note that we again need to settle for a smaller gamma value, which we'll keep *fixed* during the optimisation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-13T13:32:06.105670Z",
     "start_time": "2018-06-13T13:32:02.550268Z"
    }
   },
   "outputs": [],
   "source": [
    "svgp_ordinary = SVGP(X, Y,\n",
    "                     kern=make_matern_kernel(),\n",
    "                     likelihood=gpflow.likelihoods.Gaussian(),\n",
    "                     Z=Z,\n",
    "                     minibatch_size=50)\n",
    "\n",
    "svgp_natgrad = SVGP(X, Y,\n",
    "                    kern=make_matern_kernel(),\n",
    "                    likelihood=gpflow.likelihoods.Gaussian(),\n",
    "                    Z=Z,\n",
    "                    minibatch_size=50)\n",
    "\n",
    "# ordinary gradients with Adam for SVGP\n",
    "adam = AdamOptimizer(adam_learning_rate)\n",
    "adam_for_svgp_ordinary_tensor = adam.make_optimize_tensor(svgp_ordinary)\n",
    "\n",
    "# NatGrads and Adam for SVGP\n",
    "# Stop Adam from optimizing the variational parameters\n",
    "svgp_natgrad.q_mu.trainable = False\n",
    "svgp_natgrad.q_sqrt.trainable = False\n",
    "\n",
    "# Create the optimize_tensors for SVGP\n",
    "adam = AdamOptimizer(adam_learning_rate)\n",
    "adam_for_svgp_natgrad_tensor = adam.make_optimize_tensor(svgp_natgrad)\n",
    "\n",
    "natgrad = NatGradOptimizer(gamma=.1)\n",
    "variational_params = [(svgp_natgrad.q_mu, svgp_natgrad.q_sqrt)]\n",
    "natgrad_tensor = natgrad.make_optimize_tensor(svgp_natgrad, var_list=variational_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's optimise the models now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimize svgp_ordinary\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(adam_for_svgp_ordinary_tensor)\n",
    "\n",
    "svgp_ordinary.anchor(session)\n",
    "\n",
    "# Optimize svgp_natgrad\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(adam_for_svgp_natgrad_tensor)\n",
    "    session.run(natgrad_tensor)\n",
    "\n",
    "svgp_natgrad.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SVGP likelihood after ordinary Adam optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-207.4970"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.average([svgp_ordinary.compute_log_likelihood() for _ in notebook_range(1000)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SVGP likelihood after NatGrad and Adam optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-197.0681"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.average([svgp_natgrad.compute_log_likelihood() for _ in notebook_range(1000)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison with ordinary gradients in the non-conjugate case\n",
    "#### Binary classification\n",
    "\n",
    "##### (Take home message: natural gradients are usually better)\n",
    "\n",
    "We can use natural gradients even when the likelihood isn't Gaussian. They aren't guaranteed to be better, but they usually are in practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-13T13:32:13.397570Z",
     "start_time": "2018-06-13T13:32:10.157801Z"
    }
   },
   "outputs": [],
   "source": [
    "Y_binary = np.random.choice([1., -1], size=X.shape)\n",
    "\n",
    "vgp_bernoulli = VGP(X, Y_binary, kern=make_matern_kernel(), likelihood=gpflow.likelihoods.Bernoulli())\n",
    "vgp_bernoulli_natgrad = VGP(X, Y_binary, kern=make_matern_kernel(), likelihood=gpflow.likelihoods.Bernoulli())\n",
    "\n",
    "# ordinary gradients with Adam for VGP with Bernoulli likelihood\n",
    "adam = AdamOptimizer(adam_learning_rate)\n",
    "adam_for_vgp_bernoulli_tensor = adam.make_optimize_tensor(vgp_bernoulli)\n",
    "\n",
    "# NatGrads and Adam for VGP with Bernoulli likelihood\n",
    "# Stop Adam from optimizing the variational parameters\n",
    "vgp_bernoulli_natgrad.q_mu.trainable = False\n",
    "vgp_bernoulli_natgrad.q_sqrt.trainable = False\n",
    "\n",
    "# Create the optimize_tensors for VGP with natural gradients\n",
    "adam = AdamOptimizer(adam_learning_rate)\n",
    "adam_for_vgp_bernoulli_natgrad_tensor = adam.make_optimize_tensor(vgp_bernoulli_natgrad)\n",
    "\n",
    "natgrad = NatGradOptimizer(gamma=.1)\n",
    "variational_params = [(vgp_bernoulli_natgrad.q_mu, vgp_bernoulli_natgrad.q_sqrt)]\n",
    "natgrad_tensor = natgrad.make_optimize_tensor(vgp_bernoulli_natgrad, var_list=variational_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimize vgp_bernoulli\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(adam_for_vgp_bernoulli_tensor)\n",
    "\n",
    "vgp_bernoulli.anchor(session)\n",
    "\n",
    "# Optimize vgp_bernoulli_natgrad\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(adam_for_vgp_bernoulli_natgrad_tensor)\n",
    "    session.run(natgrad_tensor)\n",
    "\n",
    "vgp_bernoulli_natgrad.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "VGP likelihood after ordinary Adam optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-146.1206"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgp_bernoulli.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "VGP likelihood after NatGrad + Adam optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-143.9411"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgp_bernoulli_natgrad.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also choose to run natural gradients in another parameterization.<br>\n",
    "A sensible choice is the model parameters (q_mu, q_sqrt), a parameterization that is already available in GPflow as `XiSqrtMeanVar`."
   ]
  },
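  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged sketch of what a different parameterization changes, based on the general chain-rule argument rather than on the exact GPflow implementation: let $\\xi = \\xi(\\theta)$ be any invertible reparameterization of the natural parameters $\\theta$, and let $\\eta$ be the expectation parameters (all three symbols are assumptions made for illustration). The natural gradient step taken in $\\xi$ is\n",
    "\n",
    "$$\\xi \\leftarrow \\xi + \\gamma\\,\\frac{\\partial \\xi}{\\partial \\theta}\\,\\frac{\\partial \\mathcal{L}}{\\partial \\eta},$$\n",
    "\n",
    "whereas taking the step in $\\theta$ and then mapping the result to $\\xi$ gives\n",
    "\n",
    "$$\\xi\\left(\\theta + \\gamma\\,\\frac{\\partial \\mathcal{L}}{\\partial \\eta}\\right) = \\xi(\\theta) + \\gamma\\,\\frac{\\partial \\xi}{\\partial \\theta}\\,\\frac{\\partial \\mathcal{L}}{\\partial \\eta} + O(\\gamma^2).$$\n",
    "\n",
    "The two updates agree to first order in $\\gamma$ and differ only through the $O(\\gamma^2)$ curvature of the map $\\xi(\\theta)$."
   ]
  },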
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "vgp_bernoulli_natgrads_xi = VGP(X, Y_binary,\n",
    "                                kern=make_matern_kernel(),\n",
    "                                likelihood=gpflow.likelihoods.Bernoulli())\n",
    "\n",
    "var_list = [(vgp_bernoulli_natgrads_xi.q_mu, vgp_bernoulli_natgrads_xi.q_sqrt, XiSqrtMeanVar())]\n",
    "\n",
    "# Stop Adam from optimizing the variational parameters\n",
    "vgp_bernoulli_natgrads_xi.q_mu.trainable = False\n",
    "vgp_bernoulli_natgrads_xi.q_sqrt.trainable = False\n",
    "\n",
    "# Create the optimize_tensors for VGP with Bernoulli likelihood\n",
    "adam = AdamOptimizer(adam_learning_rate)\n",
    "adam_for_vgp_bernoulli_natgrads_xi_tensor = adam.make_optimize_tensor(vgp_bernoulli_natgrads_xi)\n",
    "\n",
    "natgrad = NatGradOptimizer(gamma=.01)\n",
    "natgrad_tensor = natgrad.make_optimize_tensor(vgp_bernoulli_natgrads_xi, var_list=var_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimize vgp_bernoulli_natgrads_xi\n",
    "for _ in range(notebook_niter(100)):\n",
    "    session.run(adam_for_vgp_bernoulli_natgrads_xi_tensor)\n",
    "    session.run(natgrad_tensor)\n",
    "\n",
    "vgp_bernoulli_natgrads_xi.anchor(session)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "VGP likelihood after NatGrads with XiSqrtMeanVar + Adam optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-143.9014"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vgp_bernoulli_natgrads_xi.compute_log_likelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With sufficiently small steps, the choice of transform shouldn't matter,\n",
    "but for large steps it can make a difference in practice."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
